//! 进程间传输的实现（使用命名管道）

use bytes::{Buf, BufMut, Bytes, BytesMut};
use flume::{
    bounded,
    r#async::{RecvStream as AsyncRecvStream, SendSink as AsyncSendSink},
    Receiver, RecvError, Sender, TryRecvError,
};
use futures_lite::Stream;
use futures_sink::Sink;
use pin_project::pin_project;
use quic_rpc::{
    transport::{ConnectionErrors, Connector, Listener, LocalAddr, StreamTypes},
    RpcMessage,
};
use serde::{Deserialize, Serialize};
use std::{
    error::Error,
    fmt::{Debug, Display, Formatter, Result as FmtResult},
    future::Future,
    io::{Error as IoError, ErrorKind, Result as IoResult},
    marker::PhantomData,
    pin::Pin,
    task::{Context, Poll},
};
use tokio::{
    io::{AsyncRead, AsyncWrite, ReadBuf},
    net::windows::named_pipe::{ClientOptions, ServerOptions},
    task::JoinHandle,
    time::{sleep, Duration},
};
use tracing::error;

/// Capacity of the bounded flume channels that connect each pipe's proxy task to the RPC
/// layer (one channel per direction).
const MAX_CHANNEL_CAPACITY: usize = 32;

#[pin_project]
pub struct SendSink<Out: 'static>(#[pin] AsyncSendSink<'static, Out>);

impl<Out> Debug for SendSink<Out> {
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        f.debug_struct("SendSink").finish()
    }
}

impl<Out> Sink<Out> for SendSink<Out> {
    type Error = IoError;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Pin::new(&mut self.project().0)
            .poll_ready(cx)
            .map_err(|e| IoError::new(ErrorKind::NetworkUnreachable, format!("Can't send {}", e)))
    }

    fn start_send(self: Pin<&mut Self>, item: Out) -> Result<(), Self::Error> {
        Pin::new(&mut self.project().0)
            .start_send(item)
            .map_err(|e| IoError::new(ErrorKind::NetworkUnreachable, format!("Can't send {}", e)))
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Pin::new(&mut self.project().0)
            .poll_flush(cx)
            .map_err(|e| IoError::new(ErrorKind::NetworkUnreachable, format!("Can't send {}", e)))
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Pin::new(&mut self.project().0)
            .poll_close(cx)
            .map_err(|e| IoError::new(ErrorKind::NetworkUnreachable, format!("Can't send {}", e)))
    }
}

#[pin_project]
pub struct RecvStream<In: 'static>(#[pin] AsyncRecvStream<'static, In>);

impl<In> Debug for RecvStream<In> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RecvStream").finish()
    }
}

impl<In> Stream for RecvStream<In> {
    type Item = Result<In, IoError>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match Pin::new(&mut self.project().0).poll_next(cx) {
            Poll::Ready(Some(v)) => Poll::Ready(Some(Ok(v))),
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }
}

#[pin_project]
struct PipeStreamProxyTask<I, O, T> {
    /// 命名管道对象
    #[pin]
    stream: T,
    /// 用于序列化的缓冲区
    #[pin]
    buffer_serialize: Option<BytesMut>,
    /// 用于从管道流读取数据的缓冲区
    #[pin]
    buffer_read: Option<BytesMut>,
    /// 用于把数据写入管道流的缓冲区
    #[pin]
    buffer_write: Option<Bytes>,
    /// 管道写入操作的关闭状态
    #[pin]
    status_shutdown: Poll<()>,
    /// 从管道中接收到数据并反序列化后，发送给指定的Receiver
    #[pin]
    tx_inner: Option<Sender<I>>,
    /// 来自Sender的数据，即将序列化写入管道流
    #[pin]
    rx_inner: Option<Receiver<O>>,
}

impl<I, O, T> PipeStreamProxyTask<I, O, T> {
    fn new(stream: T, tx: Sender<I>, rx: Receiver<O>) -> Self {
        Self {
            stream,
            buffer_serialize: None,
            buffer_read: None,
            buffer_write: None,
            status_shutdown: Poll::Pending,
            tx_inner: Some(tx),
            rx_inner: Some(rx),
        }
    }

    fn serialize<S: Serialize>(
        buffer: Pin<&mut Option<BytesMut>>,
        data: &S,
    ) -> Result<Bytes, postcard::Error> {
        let buffer = buffer.get_mut();
        let buf = buffer.take().unwrap_or_default();
        let mut buf = postcard::to_io(data, buf.writer())?.into_inner();
        if buf.len() <= 1024 {
            let res = buf.split().freeze();
            buffer.replace(buf);
            Ok(res)
        } else {
            Ok(buf.freeze())
        }
    }

    fn deserialize<'a, D: Deserialize<'a>>(data: &'a [u8]) -> Result<D, postcard::Error> {
        postcard::from_bytes(data)
    }

    fn write(
        buffer_serialize: Pin<&mut Option<BytesMut>>,
        buffer_write: Pin<&mut Option<Bytes>>,
        data: Option<O>,
    ) where
        O: Serialize,
    {
        if let Ok(buf) = Self::serialize(buffer_serialize, &data) {
            let mut buf2 = BytesMut::with_capacity(size_of::<u16>() + buf.len());
            buf2.put_u16(buf.len() as _);
            buf2.put_slice(&buf);
            buffer_write.get_mut().replace(buf2.freeze());
        }
    }

    fn read(buffer_read: Pin<&mut Option<BytesMut>>, data: &[u8]) -> Result<Option<I>, IoError>
    where
        for<'a> I: Deserialize<'a>,
    {
        let buffer = buffer_read.get_mut();
        let mut buf = buffer.take().map_or_else(
            || {
                if data.len() < size_of::<u16>() {
                    Err(IoError::new(ErrorKind::InvalidData, "Length not enough."))
                } else {
                    Ok(Default::default())
                }
            },
            Ok,
        )?;
        buf.put_slice(data);
        let src = buf.clone();
        let len = buf.get_u16() as _;
        let data = if len > buf.len() {
            buffer.replace(src);
            return Err(IoError::new(ErrorKind::InvalidData, "Length not enough."));
        } else if len < buf.len() {
            buffer.replace(BytesMut::from(&buf[len..]));
            &buf[..len]
        } else {
            &buf[..len]
        };
        Self::deserialize(data).map_err(|e| IoError::new(ErrorKind::InvalidData, e))
    }
}

impl<I, O, T> Future for PipeStreamProxyTask<I, O, T>
where
    for<'a> I: Deserialize<'a>,
    O: Serialize,
    T: AsyncRead + AsyncWrite + Sync + Send + Unpin,
{
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut p = self.project();

        // 第一步： 如果有准备写入管道的数据，则处理并写入
        if let Some(buf) = p.buffer_write.take() {
            match Pin::new(&mut p.stream).poll_write(cx, &buf) {
                Poll::Ready(Ok(n)) if n < buf.len() => {
                    *p.buffer_write = Some(Bytes::copy_from_slice(&buf[n..]))
                }
                Poll::Pending => *p.buffer_write = Some(buf),
                _ => (),
            }
        }

        // 第二部： 如果当前写入缓冲区没有数据，则生产数据
        if p.buffer_write.is_none() {
            if let Some(rx) = p.rx_inner.clone() {
                match rx.try_recv() {
                    Ok(data) => Self::write(p.buffer_serialize, p.buffer_write, Some(data)),
                    Err(TryRecvError::Disconnected) if p.status_shutdown.is_pending() => {
                        Self::write(p.buffer_serialize, p.buffer_write, None);
                        *p.status_shutdown = Pin::new(&mut p.stream).poll_shutdown(cx).map(|_| ());
                        drop(p.rx_inner.take());
                    }
                    _ => (),
                };
            } else if p.status_shutdown.is_pending() {
                *p.status_shutdown = Pin::new(&mut p.stream).poll_shutdown(cx).map(|_| ());
            }
        }

        // 第三部： 读取管道的数据到缓冲区并尝试处理
        let mut buffer = [0u8; 32];
        let mut buf = ReadBuf::new(&mut buffer);
        match Pin::new(&mut p.stream).poll_read(cx, &mut buf) {
            Poll::Ready(Ok(_)) => if buf.filled().len() > 0 {
                if let Ok(data) = Self::read(p.buffer_read, buf.filled()) {
                    match data {
                        Some(data) => {
                            if let Some(tx) = p.tx_inner.clone() {
                                if tx.send(data).is_err() {
                                    drop(p.tx_inner.take());
                                };
                            }
                        }
                        None => drop(p.tx_inner.take()),
                    }
                }
            } else {
                // 按照AsyncRead的规定，如果Ready(Ok(()))且filled length是0，则读取通道已经不可用（达到EOF）
                drop(p.tx_inner.take())
            }
            Poll::Ready(Err(e)) => {
                error!(?e, "Reading error.");
                drop(p.tx_inner.take())
            }
            _ => (),
        }

        // 第四部： 检查任务是否已经完成
        if p.rx_inner.is_none() && p.tx_inner.is_none() && p.status_shutdown.is_ready() {
            return Poll::Ready(());
        }

        cx.waker().wake_by_ref();
        Poll::Pending
    }
}

/// Listener that accepts client connections on a Windows named pipe.
///
/// A background task started by [`PipeListener::serve`] accepts clients and queues them for
/// [`Listener::accept`]. Clones share that queue but do not own the task. Only the instance
/// returned by `serve` stops it, either through [`PipeListener::stop`] or when dropped.
#[derive(Debug)]
pub struct PipeListener<In: RpcMessage, Out: RpcMessage> {
    /// Connected clients waiting to be returned by `accept`.
    accepting: Receiver<(AsyncSendSink<'static, Out>, AsyncRecvStream<'static, In>)>,
    /// The background accept loop; `None` on clones and after `stop`.
    handle: Option<JoinHandle<Result<(), IoError>>>,
}

impl<In: RpcMessage, Out: RpcMessage> PipeListener<In, Out> {
    /// Starts accepting clients on the named pipe `pipe_name`.
    ///
    /// The first pipe instance is created before this function returns, so problems such as
    /// an invalid name or missing permissions are reported to the caller. Otherwise they
    /// would be lost inside the background task. Must be called from within a Tokio runtime.
    ///
    /// # Errors
    ///
    /// Returns the error from creating the first pipe instance.
    pub fn serve(pipe_name: &str) -> IoResult<Self> {
        let (accepting_tx, accepting) = bounded(MAX_CHANNEL_CAPACITY);
        let addr = pipe_name.to_string();
        let mut server = ServerOptions::new().create(&addr)?;

        let handle = tokio::spawn(async move {
            loop {
                server.connect().await?;
                // Create the next instance before handing this one off. This way the pipe
                // name never disappears between connections; a client opening it during such
                // a gap would fail with `NotFound`.
                let next = ServerOptions::new().create(&addr)?;
                let connected = std::mem::replace(&mut server, next);

                let (tx, rx_inner) = bounded(MAX_CHANNEL_CAPACITY);
                let (tx_inner, rx) = bounded(MAX_CHANNEL_CAPACITY);
                tokio::spawn(PipeStreamProxyTask::new(connected, tx_inner, rx_inner));
                accepting_tx
                    .send_async((tx.into_sink(), rx.into_stream()))
                    .await
                    .map_err(|e| IoError::new(ErrorKind::NetworkUnreachable, e))?;
            }
        });

        Ok(Self {
            accepting,
            handle: Some(handle),
        })
    }

    /// Stops the background accept loop.
    ///
    /// Does nothing on clones or if the loop was already stopped. Connections accepted earlier
    /// keep running, because their proxy tasks are spawned separately. Once the queue drains,
    /// `accept` returns an error.
    pub fn stop(&mut self) {
        if let Some(handle) = self.handle.take() {
            handle.abort();
        }
    }
}

impl<In: RpcMessage, Out: RpcMessage> Drop for PipeListener<In, Out> {
    fn drop(&mut self) {
        self.stop()
    }
}

impl<In: RpcMessage, Out: RpcMessage> Clone for PipeListener<In, Out> {
    /// Clones share the accept queue but not ownership of the background task.
    fn clone(&self) -> Self {
        Self {
            accepting: self.accepting.clone(),
            handle: None,
        }
    }
}

impl<In: RpcMessage, Out: RpcMessage> ConnectionErrors for PipeListener<In, Out> {
    type OpenError = self::OpenError;
    type AcceptError = self::AcceptError;
    type SendError = IoError;
    type RecvError = IoError;
}

impl<In: RpcMessage, Out: RpcMessage> StreamTypes for PipeListener<In, Out> {
    type In = In;
    type Out = Out;
    type RecvStream = RecvStream<In>;
    type SendSink = SendSink<Out>;
}

impl<In: RpcMessage, Out: RpcMessage> Listener for PipeListener<In, Out> {
    /// Waits for the next client queued by the background accept loop.
    // Allowed on purpose: this `async fn` signature is more specific than the trait's.
    #[allow(refining_impl_trait)]
    async fn accept(&self) -> Result<(Self::SendSink, Self::RecvStream), AcceptError> {
        match self.accepting.recv_async().await {
            Ok((sink, stream)) => Ok((SendSink(sink), RecvStream(stream))),
            Err(err) => Err(AcceptError::RecvError(err)),
        }
    }

    /// A named pipe has no socket address, so `LocalAddr::Mem` is reported.
    fn local_addr(&self) -> &[LocalAddr] {
        &[LocalAddr::Mem]
    }
}

/// Connector that opens client connections to a named pipe.
///
/// Only the pipe name is stored here; the connection itself is established in `open`.
#[derive(Debug)]
pub struct PipeConnector<In: RpcMessage, Out: RpcMessage> {
    /// Name of the pipe to connect to.
    addr: String,
    /// Binds the message types to the connector without storing any values.
    _p: PhantomData<(In, Out)>,
}

impl<In: RpcMessage, Out: RpcMessage> PipeConnector<In, Out> {
    /// Creates a connector for the pipe named `pipe_name`.
    ///
    /// No connection is attempted here, and this currently never returns an error.
    pub fn new(pipe_name: &str) -> IoResult<Self> {
        let addr = String::from(pipe_name);
        Ok(Self {
            addr,
            _p: PhantomData,
        })
    }
}

impl<In: RpcMessage, Out: RpcMessage> Clone for PipeConnector<In, Out> {
    fn clone(&self) -> Self {
        let addr = self.addr.clone();
        Self {
            addr,
            _p: PhantomData,
        }
    }
}

impl<In: RpcMessage, Out: RpcMessage> ConnectionErrors for PipeConnector<In, Out> {
    type SendError = IoError;
    type RecvError = IoError;
    type OpenError = self::OpenError;
    type AcceptError = self::AcceptError;
}

impl<In: RpcMessage, Out: RpcMessage> StreamTypes for PipeConnector<In, Out> {
    type In = In;
    type Out = Out;
    type SendSink = SendSink<Out>;
    type RecvStream = RecvStream<In>;
}

impl<In: RpcMessage, Out: RpcMessage> Connector for PipeConnector<In, Out> {
    #[allow(refining_impl_trait)]
    async fn open(&self) -> Result<(Self::SendSink, Self::RecvStream), Self::OpenError> {
        let client = loop {
            match ClientOptions::new().open(&self.addr) {
                Ok(client) => break client,
                _ => sleep(Duration::from_millis(5)).await,
            }
        };

        let (tx, rx_inner) = bounded(MAX_CHANNEL_CAPACITY);
        let (tx_inner, rx) = bounded(MAX_CHANNEL_CAPACITY);
        tokio::spawn(PipeStreamProxyTask::new(client, tx_inner, rx_inner));
        Ok((SendSink(tx.into_sink()), RecvStream(rx.into_stream())))
    }
}

/// 接受客户端时的错误类型
#[derive(Debug)]
pub enum AcceptError {
    RecvError(RecvError),
}

impl Display for AcceptError {
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        Debug::fmt(self, f)
    }
}

impl Error for AcceptError {}

/// Error returned when opening a client connection to a pipe fails.
#[derive(Debug)]
pub enum OpenError {
    /// Opening the client end of the pipe failed with an I/O error.
    Io(IoError),
}

impl Display for OpenError {
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        match self {
            Self::Io(e) => write!(f, "failed to open pipe connection: {}", e),
        }
    }
}

impl Error for OpenError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            Self::Io(e) => Some(e),
        }
    }
}
